{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 综述区别\n",
    "llama3区别于llama2在模型层面的区别主要体现在全模型使用GQA。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import struct\n",
    "import inspect\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Optional, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型参数定义\n",
    "其中`@dataclass`是`Python`中的一个装饰器，当你在类定义前加上它时，`Python`会自动为这个类生成一些常用的特殊方法，比如 `__init__()`、`__repr__()`、`__eq__()` 等。\n",
    "- `__init__()`主要体现在不用写`class`后带一堆参数，然后再`self.xx=x`这一坨代码了,\n",
    "- 还有个好处是`__repr__()`支持打印该类，具体有哪些参数,\n",
    "- `__eq__()`支持用`==`直接比较两个实例的各个字段值是否相同。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class ModelArgs:\n",
    "    \"\"\"Hyperparameters shared by every module of the model.\"\"\"\n",
    "\n",
    "    dim: int = 4096  # width of the token embeddings and of every block's input/output\n",
    "    n_layers: int = 6  # number of TransformerBlock layers stacked in Transformer\n",
    "    n_heads: int = 6  # number of query heads in Attention\n",
    "    # Query heads per shared key/value head; Attention builds n_heads // n_group KV heads.\n",
    "    n_group: Optional[int] = 3\n",
    "    vocab_size: int = 4096  # rows of the embedding table and width of the output logits\n",
    "    hidden_dim: Optional[int] = None  # FFN hidden width; None lets FeedForward derive it from dim\n",
    "    multiple_of: int = 256  # the derived FFN hidden width is rounded up to a multiple of this (see FeedForward)\n",
    "    norm_eps: float = 1e-5  # epsilon added inside RMSNorm before the reciprocal square root\n",
    "    max_seq_len: int = 2048  # side length of the causal mask and number of rows in the RoPE tables\n",
    "    dropout: float = 0.0  # probability used by the embedding, attention and FFN dropout layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RMS归一化（RMSNorm），在Qwen-blog已讲。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RMSNorm(torch.nn.Module):\n",
    "    \"\"\"Root-mean-square layer normalisation with a learnable per-channel gain.\"\"\"\n",
    "\n",
    "    def __init__(self, dim: int, eps: float):\n",
    "        super().__init__()\n",
    "        # eps is added to the mean square before the reciprocal square root.\n",
    "        self.eps = eps\n",
    "        # The gain starts at one, so a fresh layer only divides by the RMS.\n",
    "        self.weight = nn.Parameter(torch.ones(dim))\n",
    "\n",
    "    def _norm(self, x):\n",
    "        \"\"\"Scale ``x`` by the reciprocal root mean square of its last dimension.\"\"\"\n",
    "        mean_square = x.pow(2).mean(-1, keepdim=True)\n",
    "        inv_rms = torch.rsqrt(mean_square + self.eps)\n",
    "        return x * inv_rms\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Normalise in float32, cast back to the dtype of ``x``, then apply the gain.\"\"\"\n",
    "        normed = self._norm(x.float())\n",
    "        return normed.type_as(x) * self.weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ROPE相关，在Qwen-blog已讲"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):\n",
    "    \"\"\"Build the RoPE cosine and sine lookup tables.\n",
    "\n",
    "    Args:\n",
    "        dim: per-head dimension; one frequency is produced per channel pair.\n",
    "        end: number of positions to tabulate.\n",
    "        theta: base of the geometric frequency progression.\n",
    "\n",
    "    Returns:\n",
    "        ``(freqs_cos, freqs_sin)``, each of shape ``(end, dim // 2)``.\n",
    "    \"\"\"\n",
    "    pair_exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim\n",
    "    inv_freq = 1.0 / (theta ** pair_exponents)\n",
    "    positions = torch.arange(end, device=inv_freq.device)\n",
    "    # angles[p, k] = p * inv_freq[k]\n",
    "    angles = torch.outer(positions, inv_freq).float()\n",
    "    return torch.cos(angles), torch.sin(angles)\n",
    "\n",
    "def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n",
    "    \"\"\"View ``freqs_cis`` so that it broadcasts against ``x``.\n",
    "\n",
    "    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result keeps\n",
    "    those two sizes at axis 1 and at the last axis and has size 1 everywhere else.\n",
    "    \"\"\"\n",
    "    rank = x.ndim\n",
    "    assert rank > 1\n",
    "    assert freqs_cis.shape == (x.shape[1], x.shape[-1])\n",
    "    # Start from all ones and restore only the two axes that must line up.\n",
    "    target_shape = [1] * rank\n",
    "    target_shape[1] = x.shape[1]\n",
    "    target_shape[-1] = x.shape[-1]\n",
    "    return freqs_cis.view(target_shape)\n",
    "\n",
    "def apply_rotary_emb(\n",
    "    xq: torch.Tensor,\n",
    "    xk: torch.Tensor,\n",
    "    freqs_cos: torch.Tensor,\n",
    "    freqs_sin: torch.Tensor\n",
    ") -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Rotate queries and keys by their position-dependent RoPE angles.\n",
    "\n",
    "    Consecutive channel pairs of the last dimension are treated as the real and\n",
    "    imaginary parts of a complex number and multiplied by ``cos + i*sin``.\n",
    "    The math runs in float32 and each output is cast back to its input's dtype.\n",
    "    The result is flattened from axis 3 onwards, so 4-D inputs are expected.\n",
    "    \"\"\"\n",
    "\n",
    "    def split_pairs(t: torch.Tensor):\n",
    "        # (..., d) -> two tensors of shape (..., d // 2): even and odd channels.\n",
    "        return t.float().reshape(t.shape[:-1] + (-1, 2)).unbind(-1)\n",
    "\n",
    "    q_real, q_imag = split_pairs(xq)\n",
    "    k_real, k_imag = split_pairs(xk)\n",
    "\n",
    "    # Add singleton axes so the tables broadcast against the split tensors.\n",
    "    cos = reshape_for_broadcast(freqs_cos, q_real)\n",
    "    sin = reshape_for_broadcast(freqs_sin, q_real)\n",
    "\n",
    "    def rotate(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:\n",
    "        # (real + i*imag) * (cos + i*sin), then interleave the parts back into one axis.\n",
    "        rotated_real = real * cos - imag * sin\n",
    "        rotated_imag = real * sin + imag * cos\n",
    "        return torch.stack([rotated_real, rotated_imag], dim=-1).flatten(3)\n",
    "\n",
    "    return rotate(q_real, q_imag).type_as(xq), rotate(k_real, k_imag).type_as(xk)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
    "    \"\"\"Repeat every key/value head ``n_rep`` times along the head axis.\n",
    "\n",
    "    Args:\n",
    "        hidden_states: tensor of shape ``(bs, slen, n_kv_heads, head_dim)``.\n",
    "        n_rep: number of copies of each head (the group size in this notebook).\n",
    "\n",
    "    Returns:\n",
    "        The input itself when ``n_rep == 1``; otherwise a tensor of shape\n",
    "        ``(bs, slen, n_kv_heads * n_rep, head_dim)`` in which the copies of one\n",
    "        head sit next to each other.\n",
    "    \"\"\"\n",
    "    bs, slen, n_kv_heads, head_dim = hidden_states.shape\n",
    "    if n_rep == 1:\n",
    "        # Plain multi-head attention: nothing to duplicate.\n",
    "        return hidden_states\n",
    "    # Insert a size-1 axis after the head axis and broadcast it to n_rep ...\n",
    "    expanded = hidden_states.unsqueeze(3).expand(bs, slen, n_kv_heads, n_rep, head_dim)\n",
    "    # ... then merge (n_kv_heads, n_rep) into a single head axis.\n",
    "    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attention的处理与Qwen-blog大差不差.\n",
    "### 在此我将flash-attn给抹掉了，手工搓计算流程更易于学习者，欢迎大佬贡献flash-attn细节教程！！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Attention(nn.Module):\n",
    "    \"\"\"Causal grouped-query self-attention with rotary position embeddings.\n",
    "\n",
    "    ``args.n_group`` is the number of query heads that share one key/value head,\n",
    "    so the layer owns ``n_heads // n_group`` key/value heads. A group size of 1\n",
    "    (or ``None``) is ordinary multi-head attention.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, args: ModelArgs):\n",
    "        super().__init__()\n",
    "\n",
    "        # ModelArgs types n_group as Optional; treat None as \"no sharing\".\n",
    "        self.group = 1 if args.n_group is None else args.n_group\n",
    "        self.heads = args.n_heads\n",
    "        # forward() multiplies the KV heads by the group size via repeat_kv, so\n",
    "        # n_heads must be an exact multiple of it or Q and K end up with\n",
    "        # different head counts (e.g. n_heads=8, n_group=3 would rebuild only 6).\n",
    "        assert self.group >= 1 and args.n_heads % self.group == 0, (\n",
    "            f\"n_heads ({args.n_heads}) must be a positive multiple of n_group ({self.group})\"\n",
    "        )\n",
    "        self.kv_heads = args.n_heads // self.group\n",
    "        self.head_dim = args.dim // args.n_heads\n",
    "        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)\n",
    "        self.wk = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)\n",
    "        self.wv = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)\n",
    "        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)\n",
    "        self.attn_dropout = nn.Dropout(args.dropout)\n",
    "        self.resid_dropout = nn.Dropout(args.dropout)\n",
    "        self.dropout = args.dropout\n",
    "        # Additive causal mask: 0 on and below the diagonal, -inf above it.\n",
    "        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float(\"-inf\"))\n",
    "        mask = torch.triu(mask, diagonal=1)\n",
    "        self.register_buffer(\"mask\", mask)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        x: torch.Tensor,\n",
    "        freqs_cos: torch.Tensor,\n",
    "        freqs_sin: torch.Tensor,\n",
    "    ):\n",
    "        \"\"\"Apply causal self-attention to ``x``.\n",
    "\n",
    "        Args:\n",
    "            x: input of shape ``(bsz, seqlen, dim)``.\n",
    "            freqs_cos: RoPE cosine table of shape ``(seqlen, head_dim // 2)``.\n",
    "            freqs_sin: RoPE sine table of shape ``(seqlen, head_dim // 2)``.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape ``(bsz, seqlen, dim)``.\n",
    "        \"\"\"\n",
    "        bsz, seqlen, _ = x.shape\n",
    "        max_len = self.mask.shape[-1]\n",
    "        # Without this check an over-long input only fails later with an opaque broadcast error.\n",
    "        assert seqlen <= max_len, f\"sequence length {seqlen} exceeds max_seq_len {max_len}\"\n",
    "\n",
    "        # Project to Q/K/V and split the channel axis into heads.\n",
    "        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)\n",
    "        xq = xq.view(bsz, seqlen, self.heads, self.head_dim)\n",
    "        xk = xk.view(bsz, seqlen, self.kv_heads, self.head_dim)\n",
    "        xv = xv.view(bsz, seqlen, self.kv_heads, self.head_dim)\n",
    "\n",
    "        # RoPE relative positional embeddings\n",
    "        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)\n",
    "\n",
    "        # Grouped-query attention: copy each KV head so every query head has a partner.\n",
    "        xk = repeat_kv(xk, self.group)  # (bs, seqlen, n_heads, head_dim)\n",
    "        xv = repeat_kv(xv, self.group)  # (bs, seqlen, n_heads, head_dim)\n",
    "\n",
    "        # Move the head axis next to the batch axis.\n",
    "        xq = xq.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)\n",
    "        xk = xk.transpose(1, 2)\n",
    "        xv = xv.transpose(1, 2)\n",
    "\n",
    "        # Explicit attention (no flash-attn) so that every step is visible.\n",
    "        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)\n",
    "        scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_heads, seqlen, seqlen)\n",
    "        scores = F.softmax(scores.float(), dim=-1).type_as(xq)\n",
    "        scores = self.attn_dropout(scores)\n",
    "        output = torch.matmul(scores, xv)  # (bs, n_heads, seqlen, head_dim)\n",
    "\n",
    "        # Put the sequence axis back after the batch axis and concatenate the heads.\n",
    "        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)\n",
    "\n",
    "        # Output projection followed by residual dropout.\n",
    "        output = self.wo(output)\n",
    "        output = self.resid_dropout(output)\n",
    "        return output\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FFN网络，其本质是一个MLP，经过线性变换，与Qwen-blog一致。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeedForward(nn.Module):\n",
    "    \"\"\"SiLU-gated MLP: ``w2(silu(w1(x)) * w3(x))`` followed by dropout.\"\"\"\n",
    "\n",
    "    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):\n",
    "        super().__init__()\n",
    "        if hidden_dim is None:\n",
    "            # Default width: two thirds of 4 * dim, rounded up to a multiple of multiple_of.\n",
    "            two_thirds = int(2 * (4 * dim) / 3)\n",
    "            n_chunks = (two_thirds + multiple_of - 1) // multiple_of\n",
    "            hidden_dim = multiple_of * n_chunks\n",
    "        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch (goes through SiLU)\n",
    "        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # projection back to dim\n",
    "        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear branch multiplied with the gate\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map ``x`` of shape ``(..., dim)`` to a tensor of the same shape.\"\"\"\n",
    "        gate = F.silu(self.w1(x))\n",
    "        linear_branch = self.w3(x)\n",
    "        return self.dropout(self.w2(gate * linear_branch))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Block就是将上述模块组合到一起，从而形成最终的decoder-layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    \"\"\"One pre-norm decoder layer: attention and feed-forward, each with a residual connection.\"\"\"\n",
    "\n",
    "    def __init__(self, layer_id: int, args: ModelArgs):\n",
    "        super().__init__()\n",
    "        self.n_heads = args.n_heads\n",
    "        self.dim = args.dim\n",
    "        self.head_dim = args.dim // args.n_heads\n",
    "        self.attention = Attention(args)\n",
    "        self.feed_forward = FeedForward(\n",
    "            dim=args.dim,\n",
    "            hidden_dim=args.hidden_dim,\n",
    "            multiple_of=args.multiple_of,\n",
    "            dropout=args.dropout,\n",
    "        )\n",
    "        # Index of this layer inside the stack; stored for reference only.\n",
    "        self.layer_id = layer_id\n",
    "        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)\n",
    "        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)\n",
    "\n",
    "    def forward(self, x, freqs_cos, freqs_sin):\n",
    "        \"\"\"Return ``x`` after the attention and feed-forward residual updates.\n",
    "\n",
    "        ``freqs_cos`` and ``freqs_sin`` are the RoPE tables handed on to Attention.\n",
    "        \"\"\"\n",
    "        # Call the submodules themselves rather than .forward(): Module.__call__\n",
    "        # runs any registered hooks, which a direct .forward() call silently skips.\n",
    "        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)\n",
    "        out = h + self.feed_forward(self.ffn_norm(h))\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Transformer(nn.Module):\n",
    "    \"\"\"Decoder-only language model: token embedding, TransformerBlocks, final RMSNorm and a tied output head.\"\"\"\n",
    "\n",
    "    # Loss of the most recent forward() call that received targets, otherwise None.\n",
    "    last_loss: Optional[torch.Tensor]\n",
    "\n",
    "    def __init__(self, params: ModelArgs):\n",
    "        super().__init__()\n",
    "        self.params = params\n",
    "        self.vocab_size = params.vocab_size\n",
    "        self.n_layers = params.n_layers\n",
    "\n",
    "        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)  # weight shape: (vocab, dim)\n",
    "        self.dropout = nn.Dropout(params.dropout)\n",
    "        self.layers = torch.nn.ModuleList()\n",
    "        for layer_id in range(params.n_layers):\n",
    "            self.layers.append(TransformerBlock(layer_id, params))\n",
    "        self.norm = RMSNorm(params.dim, eps=params.norm_eps)\n",
    "        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)  # weight is also (vocab, dim): computes x @ W^T\n",
    "\n",
    "        # Weight tying: the embedding and the output head share one parameter tensor.\n",
    "        self.tok_embeddings.weight = self.output.weight  # see https://paperswithcode.com/method/weight-tying\n",
    "\n",
    "        # Precompute the RoPE tables once; non-persistent, so they are not saved in the state dict.\n",
    "        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)\n",
    "        self.register_buffer(\"freqs_cos\", freqs_cos, persistent=False)\n",
    "        self.register_buffer(\"freqs_sin\", freqs_sin, persistent=False)\n",
    "\n",
    "        # init all weights\n",
    "        self.apply(self._init_weights)\n",
    "        # Scaled init for the projections that write into the residual stream (GPT-2 paper):\n",
    "        # the attention output `wo` and the FFN output `w2`. `w3` is the FFN's\n",
    "        # multiplicative branch, not a residual projection, so it keeps the default std.\n",
    "        for pn, p in self.named_parameters():\n",
    "            if pn.endswith('w2.weight') or pn.endswith('wo.weight'):\n",
    "                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))\n",
    "\n",
    "        # Set by forward() when it is called with a targets tensor.\n",
    "        self.last_loss = None\n",
    "\n",
    "    def _init_weights(self, module):\n",
    "        \"\"\"Initialise Linear and Embedding weights from N(0, 0.02^2) and zero any Linear bias.\"\"\"\n",
    "        if isinstance(module, nn.Linear):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
    "            if module.bias is not None:\n",
    "                torch.nn.init.zeros_(module.bias)\n",
    "        elif isinstance(module, nn.Embedding):\n",
    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
    "\n",
    "    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:\n",
    "        \"\"\"Run the model on a batch of token ids.\n",
    "\n",
    "        Args:\n",
    "            tokens: integer tensor of shape ``(bsz, seqlen)``.\n",
    "            targets: optional target ids; entries equal to -1 are ignored by the loss.\n",
    "\n",
    "        Returns:\n",
    "            With ``targets``: logits for every position, ``(bsz, seqlen, vocab_size)``,\n",
    "            and ``self.last_loss`` holds the cross-entropy loss.\n",
    "            Without ``targets``: logits for the last position only, ``(bsz, vocab_size)``,\n",
    "            and ``self.last_loss`` is None.\n",
    "        \"\"\"\n",
    "        _bsz, seqlen = tokens.shape\n",
    "        max_len = self.freqs_cos.shape[0]\n",
    "        # Fail early with a readable message instead of a bare assert inside the RoPE helper.\n",
    "        assert seqlen <= max_len, f\"sequence length {seqlen} exceeds max_seq_len {max_len}\"\n",
    "        h = self.tok_embeddings(tokens)\n",
    "        h = self.dropout(h)\n",
    "        freqs_cos = self.freqs_cos[:seqlen]\n",
    "        freqs_sin = self.freqs_sin[:seqlen]\n",
    "\n",
    "        for layer in self.layers:\n",
    "            h = layer(h, freqs_cos, freqs_sin)\n",
    "        h = self.norm(h)\n",
    "\n",
    "        if targets is not None:\n",
    "            # Training path: logits for every position plus the loss.\n",
    "            logits = self.output(h)\n",
    "            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)\n",
    "        else:\n",
    "            # Inference path: only the last position is needed to predict the next token.\n",
    "            # Indexing with the list [-1] keeps a length-1 time axis for the head;\n",
    "            # the reshape then drops it, giving (bsz, vocab_size).\n",
    "            logits = self.output(h[:, [-1], :]).reshape(_bsz,-1)\n",
    "            self.last_loss = None\n",
    "\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 运行测试一下！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model with the default hyperparameters declared in ModelArgs.\n",
    "conf = ModelArgs()\n",
    "model = Transformer(conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 7.0483e-01, -9.6776e-01,  9.7280e-01,  ..., -1.2867e-01,\n",
       "          1.4969e+00,  4.8647e-01],\n",
       "        [-5.8376e-01,  1.7032e+00, -9.8937e-01,  ...,  1.1173e+00,\n",
       "         -2.8729e-05, -2.1118e-02],\n",
       "        [ 2.5072e+00, -1.5984e+00, -1.7395e+00,  ...,  1.0265e+00,\n",
       "         -6.7715e-01, -2.3375e-01],\n",
       "        ...,\n",
       "        [ 3.8787e-01, -5.7305e-01,  1.7864e+00,  ...,  3.1860e-02,\n",
       "         -1.0664e-01,  2.7426e+00],\n",
       "        [-2.3487e+00,  6.1362e-02, -1.9033e-01,  ..., -1.4094e-01,\n",
       "         -1.8798e-01,  2.4475e-01],\n",
       "        [-5.5272e-01,  6.0121e-01, -3.0820e-01,  ..., -6.4737e-01,\n",
       "         -9.3688e-01, -4.5705e-01]], grad_fn=<ViewBackward0>)"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random token ids (low=10, high=100): a batch of 10 sequences of length 50.\n",
    "inputs = torch.randint(low=10, high=100, size=(10,50))\n",
    "# No targets are passed, so the model returns logits for the last position only.\n",
    "output = model(inputs)\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 4096])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row of vocab_size logits per sequence in the batch.\n",
    "output.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xhr",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
